#include "grain_growth.h"
#include "png.h"
#include "cstring"
#include "iostream"
#include <fstream>
#include <sstream>
#include <string>
#include <algorithm>
#include <iterator>
#include <cassert>
#include <iomanip>
#define PNG_BYTES_TO_CHECK 4


// Gradient of the squared-error loss with respect to eta, element by element:
//   dloss[k] = 2 * (eta[k] - eta_ref[k])   for k in [0, size).
// All three buffers must hold at least `size` elements; dloss is overwritten.
void calculate_mse_loss(valueType* eta, valueType* eta_ref, valueType* dloss, uint size) {
    uint k = 0;
    while (k < size) {
        const valueType diff = eta[k] - eta_ref[k];
        dloss[k] = 2 * diff;
        ++k;
    }
}

// Sum of squared differences between eta and eta_ref over the first `size`
// elements (the MSE loss without the 1/size normalisation).
valueType sum_mse_loss(valueType* eta, valueType* eta_ref, uint size) {
    valueType total = 0;
    for (uint k = 0; k != size; ++k) {
        const valueType diff = eta[k] - eta_ref[k];
        total += diff * diff;
    }
    return total;
}

// Plain left-to-right sum of the first `size` entries of eta.
valueType sum_mtx(valueType* eta, uint size) {
    valueType total = 0;
    for (const valueType *p = eta, *stop = eta + size; p != stop; ++p)
        total += *p;
    return total;
}

// Reads a multi-frame grain-growth data file into one contiguous buffer.
//
// Layout parsed below: a header "Nx,Ny,n_grains,n_step", then n_step frames,
// each holding n_grains fields of Nx rows with Ny comma-separated values.
// Element (n, pg, x, y) is stored at
//   mtx[((n * n_grains + pg) * Nx + x) * Ny + y].
// For the first frame, per-field and whole-frame sums are printed as a sanity check.
//
// @param filename  path of the text file to read
// @param Nx, Ny, n_grains, n_step  out-params, filled from the header
// @return new[]-allocated array of Nx*Ny*n_grains*n_step values owned by the
//         caller (free with delete[]), or nullptr if the file cannot be opened
//         or its header is malformed. Values that fail to parse are left at 0
//         and the number of failures is reported on std::cerr.
valueType *read_data_file(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step)
{
    FILE *inp = fopen(filename, "r");
    if (inp == nullptr) {
        std::cerr << "read_data_file: cannot open " << filename << std::endl;
        return nullptr;
    }
    if (fscanf(inp, "%u,%u,%u,%u", &Nx, &Ny, &n_grains, &n_step) != 4) {
        std::cerr << "read_data_file: malformed header in " << filename << std::endl;
        fclose(inp);
        return nullptr;
    }
    std::cout << "header: " << Nx << ", " << Ny << ", " << n_grains << ", " << n_step << std::endl;

    // size_t arithmetic: the uint product Nx*Ny*n_grains*n_step can overflow.
    const size_t field_size = static_cast<size_t>(Nx) * Ny;
    const size_t frame_size = field_size * n_grains;
    // Value-initialised so a truncated file leaves zeros rather than garbage.
    valueType *mtx = new valueType[frame_size * n_step]();

    auto sum_of = [](const valueType *p, size_t count) {
        valueType sum = 0;
        for (size_t i = 0; i < count; ++i)
            sum += p[i];
        return sum;
    };

    size_t failed_reads = 0;
    for (uint n = 0; n < n_step; ++n) {
        valueType *frame = mtx + n * frame_size;
        for (uint pg = 0; pg < n_grains; ++pg) {
            valueType *field = frame + pg * field_size;
            for (uint x = 0; x < Nx; ++x) {
                valueType *row = field + static_cast<size_t>(x) * Ny;
                for (uint y = 0; y < Ny; ++y) {
                    // The first value of a row has no leading comma.
                    if (fscanf(inp, y == 0 ? "%lf" : ",%lf", row + y) != 1)
                        ++failed_reads;
                }
            }
            if (n == 0) {
                // Use the header dimensions, not a fixed 128x128 grid.
                std::cout << "sum of mtx first frame eta_" << pg+1 << ": " << sum_of(field, field_size) << std::endl;
            }
        }
        if (n == 0)
            std::cout << "sum of mtx first frame: " << sum_of(frame, frame_size) << std::endl;
    }
    fclose(inp);

    if (failed_reads > 0)
        std::cerr << "read_data_file: " << failed_reads << " value(s) could not be parsed in " << filename << std::endl;
    return mtx;
}

// Reads only the first frame (the initial state) of a data file whose header
// is "Nx,Ny,n_grains,n_step". n_step is returned through the out-param, but
// no frame after the first is read. Each field holds Nx rows of Ny
// comma-separated values; element (pg, x, y) is stored at mtx[(pg*Nx + x)*Ny + y].
// Every value and each field's sum are echoed to std::cout for debugging.
//
// @param filename  path of the text file to read
// @param Nx, Ny, n_grains, n_step  out-params, filled from the header
// @return new[]-allocated array of Nx*Ny*n_grains values (caller frees with
//         delete[]), or nullptr if the file cannot be opened or the header is
//         malformed. Values that fail to parse are left at 0 and counted on std::cerr.
valueType *read_init_state(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step)
{
    FILE *inp = fopen(filename, "r");
    if (inp == nullptr) {
        std::cerr << "read_init_state: cannot open " << filename << std::endl;
        return nullptr;
    }
    if (fscanf(inp, "%u,%u,%u,%u", &Nx, &Ny, &n_grains, &n_step) != 4) {
        std::cerr << "read_init_state: malformed header in " << filename << std::endl;
        fclose(inp);
        return nullptr;
    }
    std::cout << "header: " << Nx << ", " << Ny << ", " << n_grains << ", " << n_step << std::endl;

    // size_t arithmetic avoids uint overflow in the buffer size and offsets.
    const size_t field_size = static_cast<size_t>(Nx) * Ny;
    // Value-initialised so a truncated file leaves zeros rather than garbage.
    valueType *mtx = new valueType[field_size * n_grains]();

    size_t failed_reads = 0;
    for (uint pg = 0; pg < n_grains; ++pg) {
        valueType *field = mtx + pg * field_size;
        for (uint x = 0; x < Nx; ++x) {
            valueType *row = field + static_cast<size_t>(x) * Ny;
            for (uint y = 0; y < Ny; ++y) {
                // The first value of a row has no leading comma.
                if (fscanf(inp, y == 0 ? "%lf" : ",%lf", row + y) != 1)
                    ++failed_reads;
            }
        }
    }
    fclose(inp);
    if (failed_reads > 0)
        std::cerr << "read_init_state: " << failed_reads << " value(s) could not be parsed in " << filename << std::endl;

    for (uint pg = 0; pg < n_grains; ++pg) {
        valueType sum = 0;
        const valueType *eta = mtx + pg * field_size;
        for (size_t i = 0; i < field_size; ++i) {
            sum += eta[i];
            // '\n' instead of std::endl: flushing once per value is very slow.
            std::cout << "read eta_" << pg+1 << "_" << i << ": " << eta[i] << '\n';
        }
        std::cout << "sum of mtx first frame eta_" << pg+1 << ": " << sum << std::endl;
    }
    return mtx;
}

// Reads a single-frame state file whose header is "Nx,Ny,n_grains" (no n_step
// field). Each field holds Nx rows of Ny comma-separated values; element
// (pg, x, y) is stored at mtx[(pg*Nx + x)*Ny + y]. Every value and each
// field's sum are echoed to std::cout for debugging.
//
// @param filename  path of the text file to read
// @param Nx, Ny, n_grains  out-params, filled from the header
// @return new[]-allocated array of Nx*Ny*n_grains values (caller frees with
//         delete[]), or nullptr if the file cannot be opened or the header is
//         malformed. Values that fail to parse are left at 0 and counted on std::cerr.
valueType* read_init_state_1(const char* filename, uint &Nx, uint &Ny, uint &n_grains) {
  FILE *inp = fopen(filename, "r");
  if (inp == nullptr) {
    std::cerr << "read_init_state_1: cannot open " << filename << std::endl;
    return nullptr;
  }
  if (fscanf(inp, "%u,%u,%u", &Nx, &Ny, &n_grains) != 3) {
    std::cerr << "read_init_state_1: malformed header in " << filename << std::endl;
    fclose(inp);
    return nullptr;
  }
  std::cout << "header: " << Nx << ", " << Ny << ", " << n_grains << std::endl;

  // size_t arithmetic avoids uint overflow in the buffer size and offsets.
  const size_t field_size = static_cast<size_t>(Nx) * Ny;
  // Value-initialised so a truncated file leaves zeros rather than garbage.
  valueType* mtx = new valueType[field_size * n_grains]();

  size_t failed_reads = 0;
  for (uint pg = 0; pg < n_grains; ++pg) {
    valueType* field = mtx + pg * field_size;
    for (uint x = 0; x < Nx; ++x) {
      valueType* row = field + static_cast<size_t>(x) * Ny;
      for (uint y = 0; y < Ny; ++y) {
        // The first value of a row has no leading comma.
        if (fscanf(inp, y == 0 ? "%lf" : ",%lf", row + y) != 1)
          ++failed_reads;
      }
    }
  }
  fclose(inp);
  if (failed_reads > 0)
    std::cerr << "read_init_state_1: " << failed_reads << " value(s) could not be parsed in " << filename << std::endl;

  for (uint pg = 0; pg < n_grains; ++pg) {
    valueType sum = 0;
    const valueType* eta = mtx + pg * field_size;
    for (size_t i = 0; i < field_size; ++i) {
      sum += eta[i];
      // '\n' instead of std::endl: flushing once per value is very slow.
      std::cout << "read eta_" << pg+1 << "_" << i << ": " << eta[i] << '\n';
    }
    std::cout << "sum of mtx first frame eta_" << pg+1 << ": " << sum << std::endl;
  }
  return mtx;
}

struct ReturnData {
    valueType* eta1;
    valueType* eta2;
    valueType* eta1_eta2;

    ReturnData(valueType* _eta1, valueType* _eta2, valueType* _eta1_eta2) 
        : eta1(_eta1), eta2(_eta2), eta1_eta2(_eta1_eta2) {}
};

struct ReturnLabel {
    valueType* eta1_ref;
    valueType* eta2_ref;
    valueType* eta1_eta2_ref;

    ReturnLabel(valueType* _eta1_ref, valueType* _eta2_ref, valueType* _eta1_eta2_ref) 
        : eta1_ref(_eta1_ref), eta2_ref(_eta2_ref), eta1_eta2_ref(_eta1_eta2_ref) {}
};

struct ReturnItem {
    ReturnData data;
    ReturnLabel ref;
    ReturnItem(ReturnData _rd, ReturnLabel _rl) : data(_rd), ref(_rl) {}
};

// Dataset over a multi-frame grain-growth simulation held in one contiguous
// buffer. Item `i` (see get_item) pairs an input frame with the frame
// `skip_step` frames later.
//
// NOTE(review): all_data is never released by this class; ownership of the
// buffer is currently unmanaged. TODO: add RAII ownership (e.g. unique_ptr)
// once callers' copy semantics are confirmed.
class GrainGrowthDataset{
public:
    // Default member initialisers ensure that a constructor which fails to load
    // data (or never assigns these) leaves a null pointer and a zero count
    // rather than indeterminate values that get_item() would dereference.
    valueType* all_data = nullptr;   // all frames, back to back
    uint Nx = 0, Ny = 0, n_grains = 0, n_step = 0;
    int start_skip = 0;
    int skip_step = 0;               // frame distance between input and target
    int cnt = 0;                     // number of valid items

    GrainGrowthDataset(char* data_path, int _start_skip, int _skip_step);

    // Returns item `index`; pointers are nullptr when index is out of range.
    ReturnItem get_item(size_t index);

    // Number of items available through get_item().
    int get_len() const {
        return cnt;
    }
};

// Loads every frame of `data_path` into all_data and derives how many
// (input, target) items the dataset exposes.
//
// @param data_path    multi-frame data file (header "Nx,Ny,n_grains,n_step")
// @param _start_skip  stored in start_skip; 2*_start_skip frames are excluded from cnt
// @param _skip_step   frame distance between an item's input and its target
//
// If the file cannot be loaded, all_data stays nullptr and cnt is 0.
GrainGrowthDataset::GrainGrowthDataset(char* data_path, int _start_skip, int _skip_step)
    : all_data(nullptr), Nx(0), Ny(0), n_grains(0), n_step(0),
      start_skip(_start_skip), skip_step(_skip_step), cnt(0)
{
    // Load the full multi-frame dataset: get_item() indexes all_data frame by
    // frame. (This replaces a debug harness that read only the initial state from
    // a hard-coded path, leaked both buffers, and then aborted via assert(1==0).)
    all_data = read_data_file(data_path, Nx, Ny, n_grains, n_step);
    if (all_data == nullptr)
        return;

    // Compute in a wide signed type: n_step is unsigned, so the original
    // expression wrapped around when n_step < 2*_start_skip + _skip_step.
    const long long available =
        static_cast<long long>(n_step) - 2LL * _start_skip - _skip_step;
    cnt = available > 0 ? static_cast<int>(available) : 0;
}


// Returns item `index`: the input is frame `index`, the label is frame
// `index + skip_step`. Each frame is n_grains fields of Nx*Ny values; eta1 is
// the first field and eta2 the next one. The returned pointers are non-owning
// views into all_data. An out-of-range index, or a dataset with no loaded data,
// yields an item whose pointers are all nullptr.
//
// NOTE(review): start_skip reduces cnt but is not added to `index` here;
// confirm whether frames should be offset by start_skip.
// NOTE(review): eta2/eta2_ref assume n_grains >= 2.
ReturnItem GrainGrowthDataset::get_item(size_t index)
{
    auto empty_item = [] {
        return ReturnItem(ReturnData(nullptr, nullptr, nullptr),
                          ReturnLabel(nullptr, nullptr, nullptr));
    };

    // cnt is signed: rule out a non-positive count before converting it,
    // otherwise a negative cnt becomes a huge size_t and every index passes.
    if (all_data == nullptr || cnt <= 0 || index >= static_cast<size_t>(cnt))
        return empty_item();

    // The label frame must also lie inside the buffer (guards a negative
    // skip_step wrapping around, or a cnt inconsistent with n_step).
    const size_t index_ref = index + static_cast<size_t>(skip_step);
    if (index_ref >= n_step)
        return empty_item();

    const size_t field_size = static_cast<size_t>(Nx) * Ny;
    const size_t frame_size = field_size * n_grains;

    // prepare data
    valueType* eta1_eta2 = all_data + index * frame_size;
    valueType* eta1 = eta1_eta2;
    valueType* eta2 = eta1 + field_size;

    // prepare label
    valueType* eta1_eta2_ref = all_data + index_ref * frame_size;
    valueType* eta1_ref = eta1_eta2_ref;
    valueType* eta2_ref = eta1_ref + field_size;

    return ReturnItem(ReturnData(eta1, eta2, eta1_eta2),
                      ReturnLabel(eta1_ref, eta2_ref, eta1_eta2_ref));
}


